# -*- coding: utf-8 -*-  
'''
Train the LDA model.

Created on 2021-08-22
@author: luoyi
'''
import sys
import os
# Resolve the project root: keep the part of this script's absolute directory
# that precedes the first occurrence of 'tbert', then re-append 'tbert'.
# NOTE(review): if 'tbert' never occurs in the path, the result is the script
# directory with 'tbert' glued onto its end -- confirm the checkout is always
# located under a directory whose name contains 'tbert'.
_script_dir = os.path.abspath(os.path.dirname(__file__))
ROOT_PATH = _script_dir.partition('tbert')[0] + "tbert"
# Put the project root on the module search path so the project-local
# packages imported further down can be found when run as a script.
sys.path.append(ROOT_PATH)

import numpy as np
# Print floating-point values in fixed-point notation rather than
# scientific notation, which keeps printed arrays readable.
np.set_printoptions(suppress=True)

from utils.iexicon import LiteWordsWarehouse
from data.sohu_thuc_news.lda_gsdmm_dataset import LdaGSDmmPreDataset
from models.lda.nets_np import LDA


# Prepare the vocabulary (word warehouse). This runs before the dataset and
# the model are constructed; keep that order.
_words_warehouse = LiteWordsWarehouse.instance()
_words_warehouse.load_pkl()

# Prepare the dataset.
lda_ds = LdaGSDmmPreDataset()


# Prepare the model.
_TRAIN_EPOCHS = 600
lda = LDA()
# To continue a previous training run, uncomment the next line.
# lda.load_weight()
lda.training(lda_ds, epochs=_TRAIN_EPOCHS)